import numpy as np
import torch as t
import skimage.io as io

def inverseSrgb(v):
  """Decode sRGB-encoded values to linear light (NumPy version).

  The input is clipped to [0, 1] before decoding. Values at or below
  0.04045 go through the linear segment ``v / 12.92``; all others go
  through the power segment ``((v + 0.055) / 1.055) ** 2.4``.

  Args:
    v: Array-like of sRGB-encoded values.

  Returns:
    NumPy array with the decoded (linear) values, same shape as ``v``.
  """
  encoded = np.clip(v, 0, 1)
  linear_segment = encoded / 12.92
  power_segment = ((encoded + 0.055) / 1.055) ** 2.4
  uses_linear_segment = encoded <= 0.04045
  return np.where(uses_linear_segment, linear_segment, power_segment)

def srgb(v):
  """Encode linear-light values with the sRGB transfer curve (NumPy version).

  The input is clipped to [0, 1] before encoding. Values at or below
  0.0031308 go through the linear segment ``12.92 * v``; all others go
  through the power segment ``1.055 * v ** (1 / 2.4) - 0.055``.

  Args:
    v: Array-like of linear values.

  Returns:
    NumPy array with the sRGB-encoded values, same shape as ``v``.
  """
  linear = np.clip(v, 0, 1)
  low_segment = 12.92 * linear
  high_segment = 1.055 * (linear ** (1/2.4)) - 0.055
  is_low = linear <= 0.0031308
  return np.where(is_low, low_segment, high_segment)

def inverseSrgb_(v):
  """Decode sRGB-encoded values to linear light (PyTorch version).

  The input is clamped to [0, 1] before decoding. Elements at or below
  0.04045 use ``v / 12.92``; the rest use ``((v + 0.055) / 1.055) ** 2.4``.

  Args:
    v: Tensor of sRGB-encoded values.

  Returns:
    Tensor with the decoded (linear) values, same shape as ``v``.
  """
  encoded = t.clamp(v, 0, 1)
  linear_segment = encoded / 12.92
  power_segment = ((encoded + 0.055) / 1.055) ** 2.4
  uses_linear_segment = encoded <= 0.04045
  return t.where(uses_linear_segment, linear_segment, power_segment)

def srgb_(v):
  """Encode linear-light values with the sRGB transfer curve (PyTorch version).

  The input is clamped to [0, 1] before encoding. Elements at or below
  0.0031308 use ``12.92 * v``; the rest use
  ``1.055 * v ** (1 / 2.4) - 0.055``.

  The function is safe to differentiate through: the power segment is
  evaluated on a copy of the input that is clamped to the threshold, so
  its (unselected) derivative stays finite for exactly-zero inputs.

  Args:
    v: Tensor of linear values.

  Returns:
    Tensor with the sRGB-encoded values, same shape as ``v``.
  """
  threshold = 0.0031308
  v = t.clamp(v, 0, 1)
  # t.where evaluates both branches. d/dv v**(1/2.4) is infinite at v == 0,
  # and autograd multiplies it by the zero mask gradient, giving NaN. Feeding
  # the power branch values >= threshold leaves every selected output
  # unchanged while keeping the backward pass finite.
  power_input = t.clamp(v, min=threshold)
  return t.where(v <= threshold, 12.92 * v, 1.055 * (power_input ** (1/2.4)) - 0.055)

def fspecialgaussian(radius, dev):
  """Build a normalized 2-D Gaussian kernel as a torch tensor.

  The kernel is square with side ``2 * radius + 1``. Entry ``(i, j)`` is
  ``exp(-((i - radius)**2 + (j - radius)**2) / (2 * dev**2))`` and the
  whole kernel is divided by its sum so that it adds up to 1.

  Args:
    radius: Non-negative integer half-width of the kernel.
    dev: Standard deviation of the Gaussian, in pixels. Must be non-zero.

  Returns:
    Tensor of shape ``(2 * radius + 1, 2 * radius + 1)`` in torch's default
    floating-point dtype.

  Raises:
    ZeroDivisionError: If ``dev`` is zero.
  """
  two_var = dev * dev * 2
  if two_var == 0:
    # Scalar Python division used to raise here; keep that contract instead
    # of silently returning a NaN kernel from tensor division.
    raise ZeroDivisionError("fspecialgaussian: dev must be non-zero")
  # Evaluate in float64 and round once to the default dtype so the values
  # match a per-element double-precision computation.
  offsets = t.arange(-radius, radius + 1, dtype=t.float64)
  sq = offsets * offsets
  kernel = t.exp(-(sq[:, None] + sq[None, :]) / two_var)
  kernel = kernel.to(t.get_default_dtype())
  return kernel / t.sum(kernel)

def tolum(img):
  """Collapse the first three channels of ``img`` into a luminance value.

  The last axis is treated as the channel axis; only its first three
  entries are used (any further channel, such as alpha, is ignored). They
  are combined with the weights 0.2126, 0.7152 and 0.0722 (the Rec. 709 /
  sRGB luminance coefficients for R, G and B).

  Args:
    img: NumPy array whose last axis holds at least three channels.

  Returns:
    Array with the last axis removed, holding the weighted sum.
  """
  luminance_coeffs = np.array([0.2126, 0.7152, 0.0722])
  rgb = img[..., :3]
  return np.dot(rgb, luminance_coeffs)

def loadKernSrgbGS(path, badstr):
  """Load a kernel image and blend it with an identity (delta) kernel.

  The image is read from ``path``, scaled by 1/255, decoded from sRGB to
  linear light and reduced to a single luminance channel. The result is
  normalized to sum to 1 and mixed with a kernel that is 1 at the centre
  and 0 elsewhere::

      badstr * ker / sum(ker) + (1 - badstr) * identity

  so ``badstr == 1`` yields the normalized kernel and ``badstr == 0`` the
  identity kernel.

  Both colour images (H x W x C, first three channels used) and
  single-channel images (H x W) are accepted.

  NOTE(review): the division by 255 assumes an 8-bit image - confirm
  against the files this is used with. The identity kernel is built with
  odd side lengths ``2*h+1`` and ``2*w+1``, so the image is expected to
  have odd height and width. An all-black image makes the sum zero and
  the result non-finite.

  Args:
    path: Location of the kernel image, passed to ``skimage.io.imread``.
    badstr: Blend weight of the loaded kernel.

  Returns:
    Tuple ``(kernel, h, w)`` where ``kernel`` is a float32 tensor and
    ``h``/``w`` are the half-height and half-width ``(size - 1) // 2``.
  """
  raw = io.imread(path).astype(np.float32)/255
  linear = inverseSrgb(raw)
  # imread returns a 2-D array for single-channel images; those already are
  # luminance, and passing them to tolum would slice columns instead of
  # colour channels.
  lum = linear if linear.ndim == 2 else tolum(linear)
  ker = t.from_numpy(lum).to(t.float32)
  h=(ker.shape[0]-1)//2
  w=(ker.shape[1]-1)//2
  print(h,w)
  idker = t.zeros((h*2+1,w*2+1))
  idker[h,w]=1
  return ker*badstr/t.sum(ker)+(1-badstr)*idker,h,w

def composeGrayscale(a,b,c,d, inpd=10, opd=5, padcolor=0.3):
  """Tile four equally sized 2-D images into a padded 2x2 grid.

  Layout: ``a`` top-left, ``b`` top-right, ``c`` bottom-left, ``d``
  bottom-right. The tile size is taken from ``a``.

  Args:
    a, b, c, d: 2-D tensors; ``b``, ``c`` and ``d`` must fit a tile of
      ``a``'s shape.
    inpd: Gap in pixels between neighbouring tiles.
    opd: Border in pixels around the whole grid.
    padcolor: Value used for the gap and border pixels.

  Returns:
    Tensor of shape ``(2*h + 2*opd + inpd, 2*w + 2*opd + inpd)``.
  """
  tile_h, tile_w = a.shape[0], a.shape[1]
  canvas = t.zeros((tile_h*2+opd*2+inpd, tile_w*2+opd*2+inpd)) + padcolor
  # (grid row, grid column, tile), written in the order a, c, b, d.
  placements = ((0, 0, a), (1, 0, c), (0, 1, b), (1, 1, d))
  for grid_row, grid_col, tile in placements:
    top = opd + grid_row*(tile_h+inpd)
    left = opd + grid_col*(tile_w+inpd)
    canvas[top:top+tile_h, left:left+tile_w] = tile
  return canvas

def grayscaleShow(x,y,distortfn, inpd=10, opd=5, padcolor=0.3):
  """Build a 2x2 comparison image of two inputs and their distorted versions.

  ``distortfn`` is applied to ``x`` and then to ``y``. All four tensors are
  detached, moved to the CPU, sRGB-encoded with ``srgb_`` and tiled with
  ``composeGrayscale``: ``x`` top-left, ``distortfn(x)`` top-right, ``y``
  bottom-left, ``distortfn(y)`` bottom-right.

  Args:
    x, y: Tensors to display.
    distortfn: Callable taking a tensor and returning a tensor.
    inpd, opd, padcolor: Forwarded to ``composeGrayscale``.

  Returns:
    The tiled tensor produced by ``composeGrayscale``.
  """
  def to_host(tensor):
    return tensor.detach().to("cpu")

  distorted_x = to_host(distortfn(x))
  distorted_y = to_host(distortfn(y))
  plain_x = to_host(x)
  plain_y = to_host(y)
  tiles = [srgb_(img) for img in (plain_x, distorted_x, plain_y, distorted_y)]
  return composeGrayscale(*tiles, inpd, opd, padcolor)

def rgbShow(x,y,distortfn):
  """Build the ``grayscaleShow`` comparison image separately per channel.

  Channels 0, 1 and 2 along the third axis of ``x`` and ``y`` are each
  passed through ``grayscaleShow`` (so ``distortfn`` receives one 2-D
  channel at a time), and the three results are stacked along a new last
  axis.

  Args:
    x, y: Tensors indexed as ``[:, :, channel]`` with at least 3 channels.
    distortfn: Callable forwarded to ``grayscaleShow``.

  Returns:
    Tensor with the three per-channel panels stacked on the last axis.
  """
  panels = [
    grayscaleShow(x[:, :, channel], y[:, :, channel], distortfn)
    for channel in range(3)
  ]
  return t.stack(panels, dim=-1)

def torgb(x):
  """Replicate a tensor three times along a new trailing axis.

  Args:
    x: Tensor of any shape.

  Returns:
    Tensor of shape ``x.shape + (3,)`` whose three slices along the last
    axis are all equal to ``x``.
  """
  channels = (x,) * 3
  return t.stack(channels, dim=-1)

if __name__ == "__main__":
  # Script entry point: print a sample 11x11 Gaussian kernel (radius 5,
  # deviation 2.5) as a quick visual check.
  sample_kernel = fspecialgaussian(5, 2.5)
  print(sample_kernel)